{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SimpleITK Spatial Transformations <a href=\"https://mybinder.org/v2/gh/InsightSoftwareConsortium/SimpleITK-Notebooks/master?filepath=Python%2F22_Transforms.ipynb\"><img style=\"float: right;\" src=\"https://mybinder.org/badge_logo.svg\"></a>\n",
    "\n",
    "\n",
    "**Summary:**\n",
    "\n",
    "1. Points are represented by vector-like data types: Tuple, Numpy array, List.\n",
    "2. Matrices are represented by vector-like data types in row major order.\n",
    "3. Transformations are initialized to the identity transform by default.\n",
    "4. Angles are specified in radians; distances are specified in arbitrary but consistent units (nm, mm, m, km, ...).\n",
    "5. All global transformations **except translation** are of the form:\n",
    "$$T(\\mathbf{x}) = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$$\n",
    "\n",
    "   Nomenclature (when printing your transformation):\n",
    "\n",
    "   * Matrix: the matrix $A$\n",
    "   * Center: the point $\\mathbf{c}$\n",
    "   * Translation: the vector $\\mathbf{t}$\n",
    "   * Offset: $\\mathbf{t} + \\mathbf{c} - A\\mathbf{c}$\n",
    "6. Bounded transformations, BSplineTransform and DisplacementFieldTransform, behave as the identity transform outside the defined bounds.\n",
    "7. DisplacementFieldTransform:\n",
    "   * Initializing the DisplacementFieldTransform using an image requires that the image's pixel type be sitk.sitkVectorFloat64.\n",
    "   * Initializing the DisplacementFieldTransform using an image will \"clear out\" your image (your alias to the image will point to an empty, zero sized, image).\n",
    "8. Composite transformations are applied in stack order (first added, last applied).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformation Types\n",
    "This notebook introduces the transformation types supported by SimpleITK and illustrates how to \"promote\" transformations from a lower to higher parameter space (e.g. 3D translation to 3D rigid).  \n",
    "\n",
    "\n",
    "| Class Name | Details|\n",
    "|:-------------|:---------|\n",
    "|[TranslationTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1TranslationTransform.html) | 2D or 3D, translation|\n",
    "|[VersorTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1VersorTransform.html)| 3D, rotation represented by a versor|\n",
    "|[VersorRigid3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1VersorRigid3DTransform.html)|3D, rigid transformation with rotation represented by a versor|\n",
    "|[Euler2DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Euler2DTransform.html)| 2D, rigid transformation with rotation represented by an Euler angle|\n",
    "|[Euler3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Euler3DTransform.html)| 3D, rigid transformation with rotation represented by Euler angles|\n",
    "|[Similarity2DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Similarity2DTransform.html)| 2D, composition of isotropic scaling and rigid transformation with rotation represented by an Euler angle|\n",
    "|[Similarity3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Similarity3DTransform.html) | 3D, composition of isotropic scaling and rigid transformation with rotation represented by a versor|\n",
    "|[ScaleTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleTransform.html)|2D or 3D, anisotropic scaling|\n",
    "|[ScaleVersor3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleVersor3DTransform.html)| 3D, rigid transformation with anisotropic scale **added** to the rotation matrix part (not composed as one would expect)|\n",
    "|[ScaleSkewVersor3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleSkewVersor3DTransform.html#details)|3D, rigid transformation with anisotropic scale and skew matrices **added** to the rotation matrix part (not composed as one would expect) |\n",
    "|[ComposeScaleSkewVersor3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ComposeScaleSkewVersor3DTransform.html)| 3D, a composition of rotation $R$, scaling $S$, and shearing $K$, $A=RSK$ in addition to translation. |\n",
    "|[AffineTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1AffineTransform.html)| 2D or 3D, affine transformation|\n",
    "|[BSplineTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1BSplineTransform.html)|2D or 3D, deformable transformation represented by a sparse regular grid of control points |\n",
    "|[DisplacementFieldTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1DisplacementFieldTransform.html)| 2D or 3D, deformable transformation represented as a dense regular grid of vectors|\n",
    "|[CompositeTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1CompositeTransform.html)| 2D or 3D, stack of transformations concatenated via composition, last added, first applied|\n",
    "|[Transform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Transform.html#details) | 2D or 3D, parent/super-class for all transforms|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import interact, fixed\n",
    "\n",
    "OUTPUT_DIR = \"Output\"\n",
    "\n",
    "print(sitk.Version())"
   ]
  },
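  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The table above describes `CompositeTransform` as a stack: the last transform added is the first one applied. The following sketch checks this on a single 2D point, using a translation and a scaling (arbitrary illustrative values) chosen so that the two possible orders give visibly different results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "# Two simple 2D transforms whose composition depends on the order of application.\n",
    "shift = sitk.TranslationTransform(2, (10.0, 0.0))\n",
    "doubling = sitk.ScaleTransform(2, (2.0, 2.0))\n",
    "\n",
    "composite = sitk.CompositeTransform(2)\n",
    "composite.AddTransform(shift)  # added first, applied last\n",
    "composite.AddTransform(doubling)  # added last, applied first\n",
    "\n",
    "demo_point = (1.0, 1.0)\n",
    "composite_result = composite.TransformPoint(demo_point)\n",
    "\n",
    "# Scale first, then translate: (1, 1) -> (2, 2) -> (12, 2).\n",
    "scale_then_translate = shift.TransformPoint(doubling.TransformPoint(demo_point))\n",
    "# Translate first, then scale: (1, 1) -> (11, 1) -> (22, 2).\n",
    "translate_then_scale = doubling.TransformPoint(shift.TransformPoint(demo_point))\n",
    "\n",
    "print(f'composite:            {composite_result}')\n",
    "print(f'scale then translate: {scale_then_translate}')\n",
    "print(f'translate then scale: {translate_then_scale}')"
   ]
  },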
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Points in SimpleITK"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utility functions\n",
    "\n",
    "Helper functions that handle point data in a uniform manner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def point2str(point, precision=1):\n",
    "    \"\"\"\n",
    "    Format a point for printing, based on specified precision with trailing zeros. Uniform printing for vector-like data\n",
    "    (tuple, numpy array, list).\n",
    "\n",
    "    Args:\n",
    "        point (vector-like): nD point with floating point coordinates.\n",
    "        precision (int): Number of digits after the decimal point.\n",
    "    Return:\n",
    "        String representation of the given point \"xx.xxx yy.yyy zz.zzz...\".\n",
    "    \"\"\"\n",
    "    return \" \".join(f\"{c:.{precision}f}\" for c in point)\n",
    "\n",
    "\n",
    "def uniform_random_points(bounds, num_points):\n",
    "    \"\"\"\n",
    "    Generate a random nD point cloud, uniformly distributed within the given bounds. The dimension is the number of pairs in the bounds input.\n",
    "\n",
    "    Args:\n",
    "        bounds (list(tuple-like)): list where each tuple defines the coordinate bounds.\n",
    "        num_points (int): number of points to generate.\n",
    "\n",
    "    Returns:\n",
    "        list containing num_points numpy arrays whose coordinates are within the given bounds.\n",
    "    \"\"\"\n",
    "    internal_bounds = [sorted(b) for b in bounds]\n",
    "    # Generate rows for each of the coordinates according to the given bounds, stack into an array,\n",
    "    # and split into a list of points.\n",
    "    mat = np.vstack(\n",
    "        [np.random.uniform(b[0], b[1], num_points) for b in internal_bounds]\n",
    "    )\n",
    "    return list(mat[: len(bounds)].T)\n",
    "\n",
    "\n",
    "def target_registration_errors(tx, point_list, reference_point_list):\n",
    "    \"\"\"\n",
    "    Distances between points transformed by the given transformation and their\n",
    "    location in another coordinate system. When the points are only used to evaluate\n",
    "    registration accuracy (not used in the registration) this is the target registration\n",
    "    error (TRE).\n",
    "    \"\"\"\n",
    "    return [\n",
    "        np.linalg.norm(np.array(tx.TransformPoint(p)) - np.array(p_ref))\n",
    "        for p, p_ref in zip(point_list, reference_point_list)\n",
    "    ]\n",
    "\n",
    "\n",
    "def print_transformation_differences(tx1, tx2):\n",
    "    \"\"\"\n",
    "    Check whether two transformations are \"equivalent\" in an arbitrary spatial region\n",
    "    either 3D or 2D, [x=(-10,10), y=(-100,100), z=(-1000,1000)]. This is just a sanity check,\n",
    "    as we are just looking at the effect of the transformations on a random set of points in\n",
    "    the region.\n",
    "    \"\"\"\n",
    "    if tx1.GetDimension() == 2 and tx2.GetDimension() == 2:\n",
    "        bounds = [(-10, 10), (-100, 100)]\n",
    "    elif tx1.GetDimension() == 3 and tx2.GetDimension() == 3:\n",
    "        bounds = [(-10, 10), (-100, 100), (-1000, 1000)]\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            \"Transformation dimensions mismatch, or unsupported transformation dimensionality\"\n",
    "        )\n",
    "    num_points = 10\n",
    "    point_list = uniform_random_points(bounds, num_points)\n",
    "    tx1_point_list = [tx1.TransformPoint(p) for p in point_list]\n",
    "    differences = target_registration_errors(tx2, point_list, tx1_point_list)\n",
    "    print(\n",
    "        tx1.GetName()\n",
    "        + \"-\"\n",
    "        + tx2.GetName()\n",
    "        + f\":\\tminDifference: {min(differences):.2f} maxDifference: {max(differences):.2f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In SimpleITK, points can be represented by any vector-like data type. In Python these include tuple, NumPy array, and list. Python itself treats these data types differently, as the output of the print calls below illustrates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SimpleITK points represented by vector-like data structures.\n",
    "point_tuple = (9.0, 10.531, 11.8341)\n",
    "point_np_array = np.array([9.0, 10.531, 11.8341])\n",
    "point_list = [9.0, 10.531, 11.8341]\n",
    "\n",
    "print(point_tuple)\n",
    "print(point_np_array)\n",
    "print(point_list)\n",
    "\n",
    "# Uniform printing with specified precision.\n",
    "precision = 2\n",
    "print(point2str(point_tuple, precision))\n",
    "print(point2str(point_np_array, precision))\n",
    "print(point2str(point_list, precision))"
   ]
  },
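  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although Python prints these types differently, SimpleITK treats them the same way wherever a point is expected. A minimal sketch (the translation is an arbitrary choice, used only to exercise `TransformPoint`): each input type is accepted, and the returned point is always a tuple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# Arbitrary 3D translation, used only to exercise TransformPoint.\n",
    "demo_translation = sitk.TranslationTransform(3, (1.0, 2.0, 3.0))\n",
    "point_inputs = [\n",
    "    (9.0, 10.531, 11.8341),\n",
    "    np.array([9.0, 10.531, 11.8341]),\n",
    "    [9.0, 10.531, 11.8341],\n",
    "]\n",
    "point_outputs = [demo_translation.TransformPoint(p) for p in point_inputs]\n",
    "for p, q in zip(point_inputs, point_outputs):\n",
    "    print(type(p).__name__, '->', type(q).__name__, q)"
   ]
  },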
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Global Transformations\n",
    "All global transformations <i>except translation</i> are of the form:\n",
    "$$T(\\mathbf{x}) = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$$\n",
    "\n",
    "In ITK speak (when printing your transformation):\n",
    "<ul>\n",
    "<li>Matrix: the matrix $A$</li>\n",
    "<li>Center: the point $\\mathbf{c}$</li>\n",
    "<li>Translation: the vector $\\mathbf{t}$</li>\n",
    "<li>Offset: $\\mathbf{t} + \\mathbf{c} - A\\mathbf{c}$</li>\n",
    "</ul>"
   ]
  },
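  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between these quantities can be checked numerically. The sketch below builds an `Euler2DTransform` with an arbitrary center, angle, and translation, reads back $A$, $\\mathbf{c}$, and $\\mathbf{t}$, computes the offset by hand from the formula above, and confirms that both $A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$ and $A\\mathbf{x}$ plus the offset agree with `TransformPoint`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# Arbitrary illustrative parameters.\n",
    "euler2d = sitk.Euler2DTransform()\n",
    "euler2d.SetCenter((5.0, 3.0))\n",
    "euler2d.SetAngle(np.pi / 6)\n",
    "euler2d.SetTranslation((1.0, -2.0))\n",
    "\n",
    "A = np.array(euler2d.GetMatrix()).reshape(2, 2)  # matrix is stored in row major order\n",
    "c = np.array(euler2d.GetCenter())\n",
    "t = np.array(euler2d.GetTranslation())\n",
    "offset = t + c - A @ c  # offset computed from the formula, not read from the transform\n",
    "\n",
    "x = np.array([7.0, 11.0])\n",
    "sitk_result = np.array(euler2d.TransformPoint(x.tolist()))\n",
    "with_center = A @ (x - c) + t + c\n",
    "with_offset = A @ x + offset\n",
    "\n",
    "print(np.allclose(sitk_result, with_center), np.allclose(sitk_result, with_offset))"
   ]
  },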
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TranslationTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A 3D translation. Note that you need to specify the dimensionality, as the sitk TranslationTransform\n",
    "# represents both 2D and 3D translations.\n",
    "dimension = 3\n",
    "offset = (1, 2, 3)  # offset can be any vector-like data\n",
    "translation = sitk.TranslationTransform(dimension, offset)\n",
    "print(translation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform a point and use the inverse transformation to get the original back.\n",
    "point = [10, 11, 12]\n",
    "transformed_point = translation.TransformPoint(point)\n",
    "translation_inverse = translation.GetInverse()\n",
    "print(\n",
    "    \"original point: \" + point2str(point) + \"\\n\"\n",
    "    \"transformed point: \" + point2str(transformed_point) + \"\\n\"\n",
    "    \"back to original: \"\n",
    "    + point2str(translation_inverse.TransformPoint(transformed_point))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Euler2DTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "point = [10, 11]\n",
    "rotation2D = sitk.Euler2DTransform()\n",
    "rotation2D.SetTranslation((7.2, 8.4))\n",
    "rotation2D.SetAngle(np.pi / 2)\n",
    "print(\n",
    "    \"original point: \" + point2str(point) + \"\\n\"\n",
    "    \"transformed point: \" + point2str(rotation2D.TransformPoint(point))\n",
    ")\n",
    "\n",
    "# Change the center of rotation so that it coincides with the point we want to\n",
    "# transform, why is this a unique configuration?\n",
    "rotation2D.SetCenter(point)\n",
    "print(\n",
    "    \"original point: \" + point2str(point) + \"\\n\"\n",
    "    \"transformed point: \" + point2str(rotation2D.TransformPoint(point))\n",
    ")"
   ]
  },
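  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to approach the question in the comment above: when $\\mathbf{x}=\\mathbf{c}$, the term $A(\\mathbf{x}-\\mathbf{c})$ vanishes, so the point is moved by exactly $\\mathbf{t}$ whatever the rotation angle. The sketch below sweeps a few angles (arbitrary values) with the center placed on the point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "center_point = (10.0, 11.0)\n",
    "shift = (7.2, 8.4)\n",
    "\n",
    "images_of_center = []\n",
    "for angle in np.linspace(0.0, 2.0 * np.pi, 5):\n",
    "    rotation_about_point = sitk.Euler2DTransform()\n",
    "    rotation_about_point.SetCenter(center_point)\n",
    "    rotation_about_point.SetTranslation(shift)\n",
    "    rotation_about_point.SetAngle(float(angle))\n",
    "    images_of_center.append(rotation_about_point.TransformPoint(center_point))\n",
    "    # The center is always mapped to center + translation, independent of the angle.\n",
    "    print(f'angle {angle:.2f} rad -> {images_of_center[-1]}')"
   ]
  },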
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VersorTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rotation only, parametrized by Versor (vector part of unit quaternion),\n",
    "# quaternion defined by rotation of theta around axis n:\n",
    "#  q = [n*sin(theta/2), cos(theta/2)]\n",
    "\n",
    "# 180 degree rotation around z axis\n",
    "\n",
    "# Use a versor:\n",
    "rotation1 = sitk.VersorTransform([0, 0, 1, 0])\n",
    "\n",
    "# Use axis-angle:\n",
    "rotation2 = sitk.VersorTransform((0, 0, 1), np.pi)\n",
    "\n",
    "# Use a matrix:\n",
    "rotation3 = sitk.VersorTransform()\n",
    "rotation3.SetMatrix([-1, 0, 0, 0, -1, 0, 0, 0, 1])\n",
    "\n",
    "point = (10, 100, 1000)\n",
    "\n",
    "p1 = rotation1.TransformPoint(point)\n",
    "p2 = rotation2.TransformPoint(point)\n",
    "p3 = rotation3.TransformPoint(point)\n",
    "\n",
    "print(\n",
    "    \"Points after transformation:\\np1=\"\n",
    "    + str(p1)\n",
    "    + \"\\np2=\"\n",
    "    + str(p2)\n",
    "    + \"\\np3=\"\n",
    "    + str(p3)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We applied the \"same\" transformation to the same point, so why are the results slightly different for the second initialization method?\n",
    "  \n",
    "This is where theory meets practice. The axis-angle initialization involves trigonometric functions, which on a fixed-precision (floating-point) machine lead to these slight differences. In many cases this is not an issue, but it is something to remember. From here on we will sweep it under the rug by printing with a more reasonable precision."
   ]
  },
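  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The root cause is visible directly: $\\cos(\\pi/2)$, which becomes the scalar part of the versor when initializing from axis-angle with $\\theta=\\pi$, evaluates to a tiny nonzero number in floating point. A practical consequence, sketched below, is that transformed points should be compared with a tolerance (here NumPy's `allclose` with its default tolerances) rather than with exact equality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "print(np.cos(np.pi / 2))  # small but not exactly zero\n",
    "\n",
    "versor_rotation = sitk.VersorTransform([0, 0, 1, 0])\n",
    "axis_angle_rotation = sitk.VersorTransform((0, 0, 1), np.pi)\n",
    "q1 = np.array(versor_rotation.TransformPoint((10, 100, 1000)))\n",
    "q2 = np.array(axis_angle_rotation.TransformPoint((10, 100, 1000)))\n",
    "\n",
    "print('max abs difference:', np.max(np.abs(q1 - q2)))\n",
    "print('equal within tolerance:', np.allclose(q1, q2))"
   ]
  },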
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Translation to Rigid [3D]\n",
    "Copy the translational component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimension = 3\n",
    "t = (1, 2, 3)\n",
    "translation = sitk.TranslationTransform(dimension, t)\n",
    "\n",
    "# Only need to copy the translational component.\n",
    "rigid_euler = sitk.Euler3DTransform()\n",
    "rigid_euler.SetTranslation(translation.GetOffset())\n",
    "\n",
    "rigid_versor = sitk.VersorRigid3DTransform()\n",
    "rigid_versor.SetTranslation(translation.GetOffset())\n",
    "\n",
    "# Sanity check to make sure the transformations are equivalent.\n",
    "bounds = [(-10, 10), (-100, 100), (-1000, 1000)]\n",
    "num_points = 10\n",
    "point_list = uniform_random_points(bounds, num_points)\n",
    "transformed_point_list = [translation.TransformPoint(p) for p in point_list]\n",
    "\n",
    "# Draw the original and transformed points, include the label so that we\n",
    "# can modify the plots without requiring explicit changes to the legend.\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection=\"3d\")\n",
    "orig = ax.scatter(\n",
    "    list(np.array(point_list).T)[0],\n",
    "    list(np.array(point_list).T)[1],\n",
    "    list(np.array(point_list).T)[2],\n",
    "    marker=\"o\",\n",
    "    color=\"blue\",\n",
    "    label=\"Original points\",\n",
    ")\n",
    "transformed = ax.scatter(\n",
    "    list(np.array(transformed_point_list).T)[0],\n",
    "    list(np.array(transformed_point_list).T)[1],\n",
    "    list(np.array(transformed_point_list).T)[2],\n",
    "    marker=\"^\",\n",
    "    color=\"red\",\n",
    "    label=\"Transformed points\",\n",
    ")\n",
    "plt.legend(loc=(0.0, 1.0))\n",
    "\n",
    "euler_errors = target_registration_errors(\n",
    "    rigid_euler, point_list, transformed_point_list\n",
    ")\n",
    "versor_errors = target_registration_errors(\n",
    "    rigid_versor, point_list, transformed_point_list\n",
    ")\n",
    "\n",
    "print(f\"Euler\\tminError: {min(euler_errors):.2f} maxError: {max(euler_errors):.2f}\")\n",
    "print(f\"Versor\\tminError: {min(versor_errors):.2f} maxError: {max(versor_errors):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rotation to Rigid [3D]\n",
    "Copy the matrix or versor and <b>center of rotation</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotationCenter = (10, 10, 10)\n",
    "rotation = sitk.VersorTransform([0, 0, 1, 0], rotationCenter)\n",
    "\n",
    "rigid_euler = sitk.Euler3DTransform()\n",
    "rigid_euler.SetMatrix(rotation.GetMatrix())\n",
    "rigid_euler.SetCenter(rotation.GetCenter())\n",
    "\n",
    "rigid_versor = sitk.VersorRigid3DTransform()\n",
    "rigid_versor.SetRotation(rotation.GetVersor())\n",
    "# rigid_versor.SetCenter(rotation.GetCenter()) #intentional error\n",
    "\n",
    "# Sanity check to make sure the transformations are equivalent.\n",
    "bounds = [(-10, 10), (-100, 100), (-1000, 1000)]\n",
    "num_points = 10\n",
    "point_list = uniform_random_points(bounds, num_points)\n",
    "transformed_point_list = [rotation.TransformPoint(p) for p in point_list]\n",
    "\n",
    "euler_errors = target_registration_errors(\n",
    "    rigid_euler, point_list, transformed_point_list\n",
    ")\n",
    "versor_errors = target_registration_errors(\n",
    "    rigid_versor, point_list, transformed_point_list\n",
    ")\n",
    "\n",
    "# Draw the points transformed by the original transformation and by the incorrect\n",
    "# transformation, to illustrate the effect of the center of rotation.\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "incorrect_transformed_point_list = [rigid_versor.TransformPoint(p) for p in point_list]\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection=\"3d\")\n",
    "orig = ax.scatter(\n",
    "    list(np.array(transformed_point_list).T)[0],\n",
    "    list(np.array(transformed_point_list).T)[1],\n",
    "    list(np.array(transformed_point_list).T)[2],\n",
    "    marker=\"o\",\n",
    "    color=\"blue\",\n",
    "    label=\"Rotation around specific center\",\n",
    ")\n",
    "transformed = ax.scatter(\n",
    "    list(np.array(incorrect_transformed_point_list).T)[0],\n",
    "    list(np.array(incorrect_transformed_point_list).T)[1],\n",
    "    list(np.array(incorrect_transformed_point_list).T)[2],\n",
    "    marker=\"^\",\n",
    "    color=\"red\",\n",
    "    label=\"Rotation around origin\",\n",
    ")\n",
    "plt.legend(loc=(0.0, 1.0))\n",
    "\n",
    "print(f\"Euler\\tminError: {min(euler_errors):.2f} maxError: {max(euler_errors):.2f}\")\n",
    "print(f\"Versor\\tminError: {min(versor_errors):.2f} maxError: {max(versor_errors):.2f}\")"
   ]
  },
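  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gap between the two point clouds in the plot is not random. For the same matrix $A$, a rotation about $\\mathbf{c}$ and a rotation about the origin differ by the constant vector $(I-A)\\mathbf{c}$, because $A(\\mathbf{x}-\\mathbf{c})+\\mathbf{c} = A\\mathbf{x} + (I-A)\\mathbf{c}$. The sketch below checks this for the 180 degree rotation about $z$ used above, on two arbitrary points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "center = (10.0, 10.0, 10.0)\n",
    "around_center = sitk.VersorTransform([0, 0, 1, 0], center)\n",
    "around_origin = sitk.VersorTransform([0, 0, 1, 0])\n",
    "\n",
    "A = np.array(around_center.GetMatrix()).reshape(3, 3)\n",
    "expected_shift = (np.eye(3) - A) @ np.array(center)  # (20, 20, 0) for this rotation\n",
    "\n",
    "shifts = []\n",
    "for p in [(1.0, 2.0, 3.0), (-5.0, 40.0, 700.0)]:\n",
    "    shift = np.array(around_center.TransformPoint(p)) - np.array(\n",
    "        around_origin.TransformPoint(p)\n",
    "    )\n",
    "    shifts.append(shift)\n",
    "    print(shift, np.allclose(shift, expected_shift))"
   ]
  },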
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity [2D]\n",
    "\n",
    "When the center of the similarity transformation is not at the origin, the effect of the transformation is not what most of us expect. This is readily visible if we limit the transformation to scaling: $T(\\mathbf{x}) = s\\mathbf{x}-s\\mathbf{c} + \\mathbf{c}$. Changing the transformation's center results in a scaling combined with a translation by $(1-s)\\mathbf{c}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_center_effect(x, y, tx, point_list, xlim, ylim):\n",
    "    tx.SetCenter((x, y))\n",
    "    transformed_point_list = [tx.TransformPoint(p) for p in point_list]\n",
    "\n",
    "    plt.scatter(\n",
    "        list(np.array(transformed_point_list).T)[0],\n",
    "        list(np.array(transformed_point_list).T)[1],\n",
    "        marker=\"^\",\n",
    "        color=\"red\",\n",
    "        label=\"transformed points\",\n",
    "    )\n",
    "    plt.scatter(\n",
    "        list(np.array(point_list).T)[0],\n",
    "        list(np.array(point_list).T)[1],\n",
    "        marker=\"o\",\n",
    "        color=\"blue\",\n",
    "        label=\"original points\",\n",
    "    )\n",
    "    plt.xlim(xlim)\n",
    "    plt.ylim(ylim)\n",
    "    plt.legend(loc=(0.25, 1.01))\n",
    "\n",
    "\n",
    "# 2D square centered on (0,0)\n",
    "points = [\n",
    "    np.array((-1.0, -1.0)),\n",
    "    np.array((-1.0, 1.0)),\n",
    "    np.array((1.0, 1.0)),\n",
    "    np.array((1.0, -1.0)),\n",
    "]\n",
    "\n",
    "# Scale by 2\n",
    "similarity = sitk.Similarity2DTransform()\n",
    "similarity.SetScale(2)\n",
    "\n",
    "interact(\n",
    "    display_center_effect,\n",
    "    x=(-10, 10),\n",
    "    y=(-10, 10),\n",
    "    tx=fixed(similarity),\n",
    "    point_list=fixed(points),\n",
    "    xlim=fixed((-10, 10)),\n",
    "    ylim=fixed((-10, 10)),\n",
    ");"
   ]
  },
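  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaling-only formula above also shows that the center is the one point the transformation leaves in place: $T(\\mathbf{c}) = s\\mathbf{c}-s\\mathbf{c}+\\mathbf{c} = \\mathbf{c}$, while every other point is pushed away from it by the factor $s$. A sketch with the same scale factor of 2 and an arbitrary center of $(1, 1)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "s = 2.0\n",
    "c = np.array([1.0, 1.0])\n",
    "similarity_demo = sitk.Similarity2DTransform()\n",
    "similarity_demo.SetScale(s)\n",
    "similarity_demo.SetCenter(c.tolist())\n",
    "\n",
    "center_image = similarity_demo.TransformPoint(c.tolist())\n",
    "corner = np.array([-1.0, -1.0])\n",
    "moved = similarity_demo.TransformPoint(corner.tolist())\n",
    "\n",
    "print('center maps to', center_image)  # the center itself\n",
    "print('corner maps to', moved, 'formula gives', s * corner - s * c + c)"
   ]
  },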
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rigid to Similarity [3D]\n",
    "Copy the translation, center, and matrix or versor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = (100, 100, 100)\n",
    "theta_x = 0.0\n",
    "theta_y = 0.0\n",
    "theta_z = np.pi / 2.0\n",
    "translation = (1, 2, 3)\n",
    "\n",
    "rigid_euler = sitk.Euler3DTransform(\n",
    "    rotation_center, theta_x, theta_y, theta_z, translation\n",
    ")\n",
    "\n",
    "similarity = sitk.Similarity3DTransform()\n",
    "similarity.SetMatrix(rigid_euler.GetMatrix())\n",
    "similarity.SetTranslation(rigid_euler.GetTranslation())\n",
    "similarity.SetCenter(rigid_euler.GetCenter())\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results\n",
    "# (see utility functions at top of notebook).\n",
    "print_transformation_differences(rigid_euler, similarity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compose Scale Skew Versor\n",
    "\n",
    "Composition of rotation $R$, scaling $S$, and shearing $K$, in addition to translation:\n",
    "$$\n",
    "T(\\mathbf{x})=RSK(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c},\\;\\; \\textrm{where } S = \\left[\\begin{array}{ccc} s_0 & 0 & 0 \\\\ 0 & s_1 & 0 \\\\ 0 & 0 & s_2 \\end{array}\\right]\\;\\; \\textrm{and } K = \\left[\\begin{array}{ccc} 1 & k_0 & k_1 \\\\ 0 & 1 & k_2 \\\\ 0 & 0 & 1 \\end{array}\\right]$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = [100, 100, 100]\n",
    "axis = [0, 0, 1]\n",
    "angle = np.pi / 2.0\n",
    "translation = [1, 2, 3]\n",
    "scale_factors = [3.14, 1.59, 2.65]\n",
    "skew = [4, 5, 6]\n",
    "compose_scale_skew_rigid1 = sitk.ComposeScaleSkewVersor3DTransform(\n",
    "    scale_factors, skew, axis, angle, translation, rotation_center\n",
    ")\n",
    "\n",
    "# The versor is n*sin(theta/2) for a unit norm axis\n",
    "versor = [a * np.sin(angle / 2.0) / np.linalg.norm(axis) for a in axis]\n",
    "compose_scale_skew_rigid2 = sitk.ComposeScaleSkewVersor3DTransform()\n",
    "# Parameter order is versor, translation, scale, skew\n",
    "compose_scale_skew_rigid2.SetParameters(versor + translation + scale_factors + skew)\n",
    "\n",
    "# Compare the two transformations, their parameters and their effect on a set of\n",
    "# random points (utility function top of notebook)\n",
    "print(f\"Transform1 parameters: {compose_scale_skew_rigid1.GetParameters()}\")\n",
    "print(f\"Transform2 parameters: {compose_scale_skew_rigid2.GetParameters()}\")\n",
    "print_transformation_differences(compose_scale_skew_rigid1, compose_scale_skew_rigid2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why don't the two transformations have the same effect on the point set even though their parameters are the same? What parameters did we forget to set for the second transformation?"
   ]
  },
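  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Separately from the missing parameters, the matrix of this transformation can be compared with the $A=RSK$ form given above. The sketch below assumes that formula (with $K$ upper triangular, as written) and builds $R$ by hand for a rotation about the $z$ axis; the translation and center are set to zero since they do not affect the matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "scale_factors = [3.14, 1.59, 2.65]\n",
    "skew = [4.0, 5.0, 6.0]\n",
    "angle = np.pi / 2.0\n",
    "rsk_tx = sitk.ComposeScaleSkewVersor3DTransform(\n",
    "    scale_factors, skew, [0, 0, 1], angle, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]\n",
    ")\n",
    "\n",
    "# Rotation about z, diagonal scaling, and upper triangular shear, as in the formula above.\n",
    "R = np.array(\n",
    "    [\n",
    "        [np.cos(angle), -np.sin(angle), 0.0],\n",
    "        [np.sin(angle), np.cos(angle), 0.0],\n",
    "        [0.0, 0.0, 1.0],\n",
    "    ]\n",
    ")\n",
    "S = np.diag(scale_factors)\n",
    "K = np.array([[1.0, skew[0], skew[1]], [0.0, 1.0, skew[2]], [0.0, 0.0, 1.0]])\n",
    "\n",
    "A = np.array(rsk_tx.GetMatrix()).reshape(3, 3)\n",
    "print('matrix equals R @ S @ K:', np.allclose(A, R @ S @ K))"
   ]
  },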
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity to Affine [3D]\n",
    "Copy the translation, center and matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = (100, 100, 100)\n",
    "axis = (0, 0, 1)\n",
    "angle = np.pi / 2.0\n",
    "translation = (1, 2, 3)\n",
    "scale_factor = 2.0\n",
    "similarity = sitk.Similarity3DTransform(\n",
    "    scale_factor, axis, angle, translation, rotation_center\n",
    ")\n",
    "\n",
    "affine = sitk.AffineTransform(3)\n",
    "affine.SetMatrix(similarity.GetMatrix())\n",
    "affine.SetTranslation(similarity.GetTranslation())\n",
    "affine.SetCenter(similarity.GetCenter())\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results\n",
    "# (see utility functions at top of notebook).\n",
    "print_transformation_differences(similarity, affine)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scale Transform\n",
    "\n",
    "Just as with the similarity transformation above, when the transformation's center is not at the origin, instead of a pure anisotropic scaling we also get a translation: $T(\\mathbf{x}) = S\\mathbf{x}-S\\mathbf{c} + \\mathbf{c}$, where $S=\\textrm{diag}(\\mathbf{s})$ is the diagonal matrix of scale factors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2D square centered on (0,0).\n",
    "points = [\n",
    "    np.array((-1.0, -1.0)),\n",
    "    np.array((-1.0, 1.0)),\n",
    "    np.array((1.0, 1.0)),\n",
    "    np.array((1.0, -1.0)),\n",
    "]\n",
    "\n",
    "# Scale by half in x and 2 in y.\n",
    "scale = sitk.ScaleTransform(2, (0.5, 2))\n",
    "\n",
    "# Interactively change the location of the center.\n",
    "interact(\n",
    "    display_center_effect,\n",
    "    x=(-10, 10),\n",
    "    y=(-10, 10),\n",
    "    tx=fixed(scale),\n",
    "    point_list=fixed(points),\n",
    "    xlim=fixed((-10, 10)),\n",
    "    ylim=fixed((-10, 10)),\n",
    ");"
   ]
  },
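  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written with the diagonal matrix $S=\\textrm{diag}(\\mathbf{s})$, the scale transformation is $T(\\mathbf{x}) = S\\mathbf{x}-S\\mathbf{c}+\\mathbf{c}$. A sketch checking this for the $(0.5, 2)$ scaling above, with an arbitrary center of $(2, 2)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "s_vec = np.array([0.5, 2.0])\n",
    "c_vec = np.array([2.0, 2.0])  # arbitrary center\n",
    "scale_demo = sitk.ScaleTransform(2, s_vec.tolist())\n",
    "scale_demo.SetCenter(c_vec.tolist())\n",
    "\n",
    "S = np.diag(s_vec)\n",
    "x = np.array([1.0, 1.0])\n",
    "expected = S @ x - S @ c_vec + c_vec\n",
    "result = np.array(scale_demo.TransformPoint(x.tolist()))\n",
    "print('SimpleITK:', result, ' formula:', expected)"
   ]
  },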
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scale Versor\n",
    "\n",
    "Despite its name, this transformation is not a composition of anisotropic scaling and a rigid transformation. It is:\n",
    "$$T(\\mathbf{x}) = (R+S)(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c},\\;\\; \\textrm{where } S= \\left[\\begin{array}{ccc} s_0-1 & 0 & 0 \\\\ 0 & s_1-1 & 0 \\\\ 0 & 0 & s_2-1 \\end{array}\\right]$$ \n",
    "\n",
    "There is no natural way of \"promoting\" the similarity transformation to this transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scales = (0.5, 0.7, 0.9)\n",
    "translation = (1, 2, 3)\n",
    "axis = (0, 0, 1)\n",
    "angle = 0.0\n",
    "scale_versor = sitk.ScaleVersor3DTransform(scales, axis, angle, translation)\n",
    "print(scale_versor)"
   ]
  },
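  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the additive form in practice, the sketch below uses a nonzero rotation ($\\pi/2$ about $z$, an arbitrary choice) and compares the transformation's matrix both with $R+S$, assuming the formula above, and with the composition $R\\,\\textrm{diag}(\\mathbf{s})$ the name might suggest. With a zero angle, as in the previous cell, the two coincide, which is why the difference is easy to miss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "scales = (0.5, 0.7, 0.9)\n",
    "angle = np.pi / 2.0  # arbitrary nonzero rotation about z\n",
    "rotated_scale_versor = sitk.ScaleVersor3DTransform(\n",
    "    scales, (0, 0, 1), angle, (0.0, 0.0, 0.0)\n",
    ")\n",
    "\n",
    "R = np.array(\n",
    "    [\n",
    "        [np.cos(angle), -np.sin(angle), 0.0],\n",
    "        [np.sin(angle), np.cos(angle), 0.0],\n",
    "        [0.0, 0.0, 1.0],\n",
    "    ]\n",
    ")\n",
    "S = np.diag(np.array(scales) - 1.0)  # the S of the formula above: s_i - 1 on the diagonal\n",
    "A = np.array(rotated_scale_versor.GetMatrix()).reshape(3, 3)\n",
    "\n",
    "print('matches R + S:      ', np.allclose(A, R + S))\n",
    "print('matches R @ diag(s):', np.allclose(A, R @ np.diag(scales)))"
   ]
  },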
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scale Skew Versor\n",
    "\n",
    "Again, despite the name, this is not a composition of transformations. It is:\n",
    "$$T(\\mathbf{x}) = (R+S+K)(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c},\\;\\; \\textrm{where } S = \\left[\\begin{array}{ccc} s_0-1 & 0 & 0 \\\\ 0 & s_1-1 & 0 \\\\ 0 & 0 & s_2-1 \\end{array}\\right]\\;\\; \\textrm{and } K = \\left[\\begin{array}{ccc} 0 & k_0 & k_1 \\\\ k_2 & 0 & k_3 \\\\ k_4 & k_5 & 0 \\end{array}\\right]$$ \n",
    "\n",
    "In practice this is an over-parametrized version of the affine transform: 15 parameters (3 scale, 6 skew, 3 versor, 3 translation) vs. 12 parameters (9 matrix, 3 translation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale = (2, 2.1, 3)\n",
    "skew = np.linspace(\n",
    "    start=0.0, stop=1.0, num=6\n",
    ")  # six equally spaced values in [0, 1], an arbitrary choice\n",
    "translation = (1, 2, 3)\n",
    "versor = (0, 0, 0, 1.0)\n",
    "scale_skew_versor = sitk.ScaleSkewVersor3DTransform(scale, skew, versor, translation)\n",
    "print(scale_skew_versor)"
   ]
  },
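  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because of this over-parametrization, different parameter sets can produce the same matrix. Assuming the additive form above, a rotation by $\\theta$ about $z$ can be reproduced with the identity versor by placing $\\cos\\theta$ on the diagonal through the scale and $\\mp\\sin\\theta$ off the diagonal through the skew entries $k_0$ and $k_2$. The sketch below builds both ($\\theta=\\pi/6$ is arbitrary), compares their matrices, and counts parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "theta = np.pi / 6\n",
    "zero = (0.0, 0.0, 0.0)\n",
    "\n",
    "# Pure rotation about z: versor (x, y, z, w) = (0, 0, sin(theta/2), cos(theta/2)).\n",
    "rotation_only = sitk.ScaleSkewVersor3DTransform(\n",
    "    (1.0, 1.0, 1.0), (0.0,) * 6, (0.0, 0.0, np.sin(theta / 2), np.cos(theta / 2)), zero\n",
    ")\n",
    "# Identity versor; the same matrix expressed through scale (diagonal) and skew (k0, k2).\n",
    "scale_skew_only = sitk.ScaleSkewVersor3DTransform(\n",
    "    (np.cos(theta), np.cos(theta), 1.0),\n",
    "    (-np.sin(theta), 0.0, np.sin(theta), 0.0, 0.0, 0.0),\n",
    "    (0.0, 0.0, 0.0, 1.0),\n",
    "    zero,\n",
    ")\n",
    "\n",
    "A1 = np.array(rotation_only.GetMatrix()).reshape(3, 3)\n",
    "A2 = np.array(scale_skew_only.GetMatrix()).reshape(3, 3)\n",
    "print('same matrix:', np.allclose(A1, A2))\n",
    "print('parameters:', len(rotation_only.GetParameters()))\n",
    "print('affine parameters:', len(sitk.AffineTransform(3).GetParameters()))"
   ]
  },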
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modify transform center without changing the transformation effect\n",
    "\n",
    "Given a transformation $T_0$ with center $\\mathbf{c_0}$, we want to change the center to $\\mathbf{c_1}$ without changing the transformation's effect. That is, $\\forall\\mathbf{x},\\;T_0(\\mathbf{x})=T_1(\\mathbf{x})$.\n",
    "\n",
    "With some simple arithmetic we see that for $T_1$ we need to set:\n",
    "* $A = A_0$\n",
    "* $\\mathbf{c}=\\mathbf{c_1}$\n",
    "* $\\mathbf{t}=A(\\mathbf{c_1}-\\mathbf{c_0}) + \\mathbf{t_0} + \\mathbf{c_0}- \\mathbf{c_1}$\n"
   ]
  },
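  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind this recipe: require $T_0(\\mathbf{x})=T_1(\\mathbf{x})$ for all $\\mathbf{x}$, set $A=A_0$ so that the terms in $\\mathbf{x}$ cancel, and solve for $\\mathbf{t}$:\n",
    "$$\n",
    "\\begin{aligned}\n",
    "A_0(\\mathbf{x}-\\mathbf{c_0}) + \\mathbf{t_0} + \\mathbf{c_0} &= A_0(\\mathbf{x}-\\mathbf{c_1}) + \\mathbf{t} + \\mathbf{c_1}\\\\\n",
    "\\mathbf{t} &= A_0\\mathbf{c_1} - A_0\\mathbf{c_0} + \\mathbf{t_0} + \\mathbf{c_0} - \\mathbf{c_1}\\\\\n",
    "&= A_0(\\mathbf{c_1}-\\mathbf{c_0}) + \\mathbf{t_0} + \\mathbf{c_0} - \\mathbf{c_1}\n",
    "\\end{aligned}\n",
    "$$"
   ]
  },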
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_translation = np.array(rigid_euler.GetTranslation())\n",
    "old_matrix = np.array(rigid_euler.GetMatrix()).reshape((3, 3))\n",
    "old_center = np.array(rigid_euler.GetCenter())\n",
    "\n",
    "rigid_euler2 = sitk.Euler3DTransform()\n",
    "new_center = np.array([2, 4, 8])\n",
    "new_translation = (\n",
    "    old_translation + old_center + old_matrix.dot(new_center - old_center) - new_center\n",
    ")\n",
    "\n",
    "rigid_euler2.SetMatrix(old_matrix.ravel())\n",
    "rigid_euler2.SetTranslation(new_translation.tolist())\n",
    "rigid_euler2.SetCenter(new_center.tolist())\n",
    "\n",
    "pnt = [16, 32, 64]\n",
    "print(rigid_euler.TransformPoint(pnt))\n",
    "print(rigid_euler2.TransformPoint(pnt))"
   ]
  },
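  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same recipe can be wrapped in a small helper. `change_center` below is a hypothetical convenience function, not part of SimpleITK. It assumes the transformation exposes `GetMatrix`, `GetCenter`, `GetTranslation`, `SetCenter`, and `SetTranslation` (as the global transformations with a translation component in this notebook do) and that it can be copied with `type(tx)(tx)`. Since the copy keeps the matrix, only the center and translation are updated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "\n",
    "def change_center(tx, new_center):\n",
    "    \"\"\"Hypothetical helper: copy tx, move its center, keep its mapping unchanged.\"\"\"\n",
    "    dim = tx.GetDimension()\n",
    "    A = np.array(tx.GetMatrix()).reshape((dim, dim))\n",
    "    c0 = np.array(tx.GetCenter())\n",
    "    t0 = np.array(tx.GetTranslation())\n",
    "    c1 = np.asarray(new_center, dtype=float)\n",
    "    new_tx = type(tx)(tx)  # copy constructor; the matrix is unchanged\n",
    "    new_tx.SetCenter(c1.tolist())\n",
    "    new_tx.SetTranslation((A @ (c1 - c0) + t0 + c0 - c1).tolist())\n",
    "    return new_tx\n",
    "\n",
    "\n",
    "euler_tx = sitk.Euler3DTransform((100, 100, 100), 0.0, 0.0, np.pi / 2.0, (1, 2, 3))\n",
    "recentered = change_center(euler_tx, (2, 4, 8))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "test_points = rng.uniform(-50.0, 50.0, size=(5, 3))\n",
    "diffs = [\n",
    "    np.linalg.norm(\n",
    "        np.array(euler_tx.TransformPoint(p.tolist()))\n",
    "        - np.array(recentered.TransformPoint(p.tolist()))\n",
    "    )\n",
    "    for p in test_points\n",
    "]\n",
    "print(f'max difference: {max(diffs):.2e}')"
   ]
  },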
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bounded Transformations\n",
    "\n",
    "SimpleITK supports two types of bounded non-rigid transformations, BSplineTransform (sparse representation) and DisplacementFieldTransform (dense representation).\n",
    "\n",
    "Transforming a point that is outside the bounds returns the original point, i.e. the transformation behaves as the identity outside its domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# This function displays the effect of the deformable transformation on a grid of points by scaling the\n",
    "# initial displacements (either of the control points for BSpline or of the deformation field itself).\n",
    "# It assumes that all points lie within the square [-2.5, 2.5] x [-2.5, 2.5].\n",
    "#\n",
    "def display_displacement_scaling_effect(\n",
    "    s, original_x_mat, original_y_mat, tx, original_control_point_displacements\n",
    "):\n",
    "    if tx.GetDimension() != 2:\n",
    "        raise ValueError(\"display_displacement_scaling_effect only works in 2D\")\n",
    "\n",
    "    plt.scatter(\n",
    "        original_x_mat,\n",
    "        original_y_mat,\n",
    "        marker=\"o\",\n",
    "        color=\"blue\",\n",
    "        label=\"original points\",\n",
    "    )\n",
    "    pointsX = []\n",
    "    pointsY = []\n",
    "    tx.SetParameters(s * original_control_point_displacements)\n",
    "\n",
    "    for index, value in np.ndenumerate(original_x_mat):\n",
    "        px, py = tx.TransformPoint((value, original_y_mat[index]))\n",
    "        pointsX.append(px)\n",
    "        pointsY.append(py)\n",
    "\n",
    "    plt.scatter(pointsX, pointsY, marker=\"^\", color=\"red\", label=\"transformed points\")\n",
    "    plt.legend(loc=(0.25, 1.01))\n",
    "    plt.xlim((-2.5, 2.5))\n",
    "    plt.ylim((-2.5, 2.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BSpline\n",
    "A BSpline transformation uses a sparse set of control points to control a free form deformation. Note that the order of parameters to the transformation is $[x_0\\ldots x_{N-1},y_0\\ldots y_{N-1}, z_0\\ldots z_{N-1}]$ for $N$ control points.\n",
    "\n",
    "\n",
    "To configure this transformation type we need to specify its bounded domain and the control point parameters, which are the displacements of the control points from their original grid positions. This can be done either explicitly, by specifying the parameters defining the domain and the control point parameters one by one, or by using a set of images that encode all of this information in a more compact manner.\n",
    "\n",
    "The next two code cells illustrate these two options."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the transformation (when working with images it is easier to use the BSplineTransformInitializer function\n",
    "# or its object oriented counterpart BSplineTransformInitializerFilter).\n",
    "dimension = 2\n",
    "spline_order = 3\n",
    "direction_matrix_row_major = [1.0, 0.0, 0.0, 1.0]  # identity, mesh is axis aligned\n",
    "origin = [-1.0, -1.0]\n",
    "domain_physical_dimensions = [2, 2]\n",
    "mesh_size = [4, 3]\n",
    "\n",
    "bspline = sitk.BSplineTransform(dimension, spline_order)\n",
    "bspline.SetTransformDomainOrigin(origin)\n",
    "bspline.SetTransformDomainDirection(direction_matrix_row_major)\n",
    "bspline.SetTransformDomainPhysicalDimensions(domain_physical_dimensions)\n",
    "bspline.SetTransformDomainMeshSize(mesh_size)\n",
    "\n",
    "# Random displacement of the control points, specifying the x and y\n",
    "# displacements separately allows us to play with these parameters,\n",
    "# just multiply one of them with zero to see the effect.\n",
    "x_displacement = np.random.random(len(bspline.GetParameters()) // 2)\n",
    "y_displacement = np.random.random(len(bspline.GetParameters()) // 2)\n",
    "original_control_point_displacements = np.concatenate([x_displacement, y_displacement])\n",
    "bspline.SetParameters(original_control_point_displacements)\n",
    "\n",
    "# Apply the BSpline transformation to a grid of points.\n",
    "# Starting the point set exactly at the origin of the BSpline mesh is problematic as\n",
    "# these points are considered outside the transformation's domain;\n",
    "# remove the epsilon below and see what happens.\n",
    "numSamplesX = 10\n",
    "numSamplesY = 20\n",
    "\n",
    "coordsX = np.linspace(\n",
    "    origin[0] + np.finfo(float).eps,\n",
    "    origin[0] + domain_physical_dimensions[0],\n",
    "    numSamplesX,\n",
    ")\n",
    "coordsY = np.linspace(\n",
    "    origin[1] + np.finfo(float).eps,\n",
    "    origin[1] + domain_physical_dimensions[1],\n",
    "    numSamplesY,\n",
    ")\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "interact(\n",
    "    display_displacement_scaling_effect,\n",
    "    s=(-1.5, 1.5),\n",
    "    original_x_mat=fixed(XX),\n",
    "    original_y_mat=fixed(YY),\n",
    "    tx=fixed(bspline),\n",
    "    original_control_point_displacements=fixed(original_control_point_displacements),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next define the same BSpline transformation using a set of coefficient images. Note that to compare the parameter values for the two transformations we need to scale the values in the new transformation using the scale value used in the GUI above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_expected": "invalid syntax"
   },
   "outputs": [],
   "source": [
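    "control_point_number = [sz + spline_order for sz in mesh_size]\n",
    "num_parameters_per_axis = np.prod(control_point_number)\n",
    "# The control point grid extends beyond the transform domain: its spacing is the domain size divided by\n",
    "# the mesh size and its origin is (spline_order-1)/2 grid spacings before the domain origin.\n",
    "grid_spacing = [sz / msz for sz, msz in zip(domain_physical_dimensions, mesh_size)]\n",
    "direction_matrix = np.array(direction_matrix_row_major).reshape(dimension, dimension)\n",
    "grid_origin = np.array(origin) - direction_matrix.dot(0.5 * (spline_order - 1) * np.array(grid_spacing))\n",
    "\n",
    "coefficient_images = []\n",
    "for i in range(dimension):\n",
    "    axis_displacements = original_control_point_displacements[i * num_parameters_per_axis : (i + 1) * num_parameters_per_axis]\n",
    "    # NumPy arrays are indexed [y,x] while SimpleITK images are indexed [x,y], so reverse the grid size.\n",
    "    coefficient_image = sitk.GetImageFromArray(axis_displacements.reshape(control_point_number[::-1]))\n",
    "    coefficient_image.SetOrigin(grid_origin.tolist())\n",
    "    coefficient_image.SetSpacing(grid_spacing)\n",
    "    coefficient_image.SetDirection(direction_matrix_row_major)\n",
    "    coefficient_images.append(coefficient_image)\n",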
    "\n",
    "bspline2 = sitk.BSplineTransform(coefficient_images, spline_order)\n",
    "\n",
    "# Note that the scale value is left intentionally blank: set the scale value based on the slider value in the GUI above.\n",
    "# You will get an error when executing the cell if a value is not provided.\n",
    "scale_factor_from_gui = \n",
    "print(np.array(bspline.GetParameters()) - np.array(bspline2.GetParameters())*scale_factor_from_gui)"
   ]
  },
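  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between the transform domain and the control point grid (the geometry of the coefficient images) can also be computed directly. The sketch below encodes our reading of ITK's BSplineTransform conventions, which you should treat as assumptions. Each axis has `mesh_size + spline_order` control points. The grid spacing is the domain's physical size divided by the mesh size. The grid origin lies `(spline_order - 1)/2` grid spacings before the domain origin. The sketch also splits a flat parameter vector $[x_0\\ldots x_{N-1},y_0\\ldots y_{N-1}]$ into one grid per axis. Note that NumPy indexes these grids as [y, x], the reverse of SimpleITK's [x, y]. The helper `bspline_grid_geometry` is hypothetical. You can compare its output with the origin, spacing and size of `bspline.GetCoefficientImages()[0]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def bspline_grid_geometry(domain_origin, domain_physical_dimensions, mesh_size, spline_order, direction=None):\n",
    "    \"\"\"Control point grid size, spacing and origin for a BSpline transform domain (assumed ITK convention).\"\"\"\n",
    "    dim = len(mesh_size)\n",
    "    direction = np.eye(dim) if direction is None else np.asarray(direction, dtype=float).reshape(dim, dim)\n",
    "    grid_size = [m + spline_order for m in mesh_size]\n",
    "    grid_spacing = np.asarray(domain_physical_dimensions, dtype=float) / np.asarray(mesh_size)\n",
    "    # The grid starts (spline_order - 1)/2 spacings before the domain origin, along the domain axes.\n",
    "    grid_origin = np.asarray(domain_origin, dtype=float) - direction.dot(0.5 * (spline_order - 1) * grid_spacing)\n",
    "    return grid_size, grid_spacing, grid_origin\n",
    "\n",
    "\n",
    "cp_grid_size, cp_grid_spacing, cp_grid_origin = bspline_grid_geometry([-1.0, -1.0], [2, 2], [4, 3], 3)\n",
    "print(f\"size: {cp_grid_size}, spacing: {cp_grid_spacing}, origin: {cp_grid_origin}\")\n",
    "\n",
    "# Split a flat parameter vector [x_0..x_{N-1}, y_0..y_{N-1}] into one grid per axis.\n",
    "num_control_points = int(np.prod(cp_grid_size))\n",
    "flat_params = np.arange(2 * num_control_points, dtype=float)\n",
    "x_grid = flat_params[:num_control_points].reshape(cp_grid_size[::-1])  # NumPy shape is (y, x)\n",
    "y_grid = flat_params[num_control_points:].reshape(cp_grid_size[::-1])\n",
    "print(x_grid.shape, y_grid.shape)"
   ]
  },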
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DisplacementField\n",
    "\n",
    "A dense set of vectors representing the displacement inside the given domain. The most generic representation of a transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the displacement field.\n",
    "\n",
    "# When working with images, the safer approach is to use the image based constructor,\n",
    "# sitk.DisplacementFieldTransform(my_image): all the fixed parameters are set correctly and the displacement\n",
    "# field is initialized using the vectors stored in the image. SimpleITK requires that the image's pixel type be\n",
    "# sitk.sitkVectorFloat64.\n",
    "displacement = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10, 20]\n",
    "field_origin = [-1.0, -1.0]\n",
    "field_spacing = [2.0 / 9.0, 2.0 / 19.0]\n",
    "field_direction = [1, 0, 0, 1]  # direction cosine matrix (row major order)\n",
    "\n",
    "# Concatenate all the information into a single list\n",
    "displacement.SetFixedParameters(\n",
    "    field_size + field_origin + field_spacing + field_direction\n",
    ")\n",
    "# Set the interpolator: either sitkLinear (the default) or sitkNearestNeighbor.\n",
    "displacement.SetInterpolator(sitk.sitkNearestNeighbor)\n",
    "\n",
    "originalDisplacements = np.random.random(len(displacement.GetParameters()))\n",
    "displacement.SetParameters(originalDisplacements)\n",
    "\n",
    "coordsX = np.linspace(\n",
    "    field_origin[0],\n",
    "    field_origin[0] + (field_size[0] - 1) * field_spacing[0],\n",
    "    field_size[0],\n",
    ")\n",
    "coordsY = np.linspace(\n",
    "    field_origin[1],\n",
    "    field_origin[1] + (field_size[1] - 1) * field_spacing[1],\n",
    "    field_size[1],\n",
    ")\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "interact(\n",
    "    display_displacement_scaling_effect,\n",
    "    s=(-1.5, 1.5),\n",
    "    original_x_mat=fixed(XX),\n",
    "    original_y_mat=fixed(YY),\n",
    "    tx=fixed(displacement),\n",
    "    original_control_point_displacements=fixed(originalDisplacements),\n",
    ");"
   ]
  },
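  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The BSpline transform groups its parameters per axis. The displacement field parameters instead appear to follow the memory layout of the underlying vector image, so the components of each pixel's displacement are interleaved: $[dx_0, dy_0, dx_1, dy_1, \\ldots]$. The sketch below checks this assumption by setting a single pixel's displacement and locating it in the parameter vector. The image size and values are arbitrary. The layout makes no difference for the random displacements used above, but it matters when you set specific parameter values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "# A 3x2 field that is zero everywhere except at pixel index (1, 0).\n",
    "layout_image = sitk.Image([3, 2], sitk.sitkVectorFloat64)\n",
    "layout_image[1, 0] = (0.5, 0.25)\n",
    "layout_tx = sitk.DisplacementFieldTransform(layout_image)\n",
    "\n",
    "layout_params = layout_tx.GetParameters()\n",
    "# 3*2 pixels with 2 components each.\n",
    "print(len(layout_params))\n",
    "# Pixel (1, 0) is the second pixel in memory, so its components are parameters 2 and 3.\n",
    "print(layout_params[2:4])"
   ]
  },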
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create a displacement field transform from an image. Remember that SimpleITK clears the image you provide, as shown in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "displacement_image = sitk.Image([64, 64], sitk.sitkVectorFloat64)\n",
    "# The only point that has any displacement is (0,0)\n",
    "displacement = (0.5, 0.5)\n",
    "displacement_image[0, 0] = displacement\n",
    "\n",
    "print(\"Original displacement image size: \" + point2str(displacement_image.GetSize()))\n",
    "\n",
    "displacement_field_transform = sitk.DisplacementFieldTransform(displacement_image)\n",
    "\n",
    "print(\n",
    "    \"After using the image to create a transform, displacement image size: \"\n",
    "    + point2str(displacement_image.GetSize())\n",
    ")\n",
    "\n",
    "# Check that the displacement field transform does what we expect.\n",
    "print(\n",
    "    f\"Expected result: {displacement}\\nActual result: {displacement_field_transform.TransformPoint((0, 0))}\"\n",
    ")"
   ]
  },
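  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you still need the image after creating the transform, you have two options. You can give the constructor a copy, `sitk.Image(displacement_image)`, so that only the copy is cleared. Alternatively, you can retrieve the field from the transform with `GetDisplacementField()`. The minimal sketch below shows both, with an arbitrary image size and displacement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "kept_image = sitk.Image([64, 64], sitk.sitkVectorFloat64)\n",
    "kept_image[0, 0] = (0.5, 0.5)\n",
    "\n",
    "# Option 1: pass a copy, so the constructor clears the copy and kept_image retains its content.\n",
    "tx_from_copy = sitk.DisplacementFieldTransform(sitk.Image(kept_image))\n",
    "print(f\"Size of the image we kept: {kept_image.GetSize()}\")\n",
    "\n",
    "# Option 2: obtain the displacement field back from the transform.\n",
    "recovered_image = tx_from_copy.GetDisplacementField()\n",
    "print(f\"Recovered field size: {recovered_image.GetSize()}, pixel (0,0): {recovered_image[0, 0]}\")"
   ]
  },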
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inverting bounded transforms\n",
    "\n",
    "In SimpleITK we cannot directly invert a BSpline transform. Luckily, there are several ways to invert a displacement field transform, and **all** transformations can be readily converted to a displacement field. Note, though, that a displacement field only approximates the original transformation. How closely it matches depends on the smoothness of the original transformation and on the sampling rate (spacing) of the displacement field.\n",
    "\n",
    "Converting a transformation to a displacement field:\n",
    "* [TransformToDisplacementFieldFilter](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1TransformToDisplacementFieldFilter.html)\n",
    "\n",
    "Options for inverting a displacement field:\n",
    "* [InvertDisplacementFieldImageFilter](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1InvertDisplacementFieldImageFilter.html)\n",
    "* [InverseDisplacementFieldImageFilter](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1InverseDisplacementFieldImageFilter.html)\n",
    "* [IterativeInverseDisplacementFieldImageFilter](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1IterativeInverseDisplacementFieldImageFilter.html)\n",
    "\n",
    "**Note**: The methods used to invert a displacement field make assumptions about the function's smoothness and continuity, and will not yield a valid result if these assumptions are not met. For example, an affine transformation representing a reflection is invertible, but inverting a displacement field representing this transformation will not yield the desired inverse.\n",
    "\n",
    "In the next cell we invert the BSpline transform we worked with above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the BSpline transform to a displacement field\n",
    "physical_size = bspline.GetTransformDomainPhysicalDimensions()\n",
    "# The displacement field spacing affects the accuracy of the transform approximation,\n",
    "# so we set it here to 0.1 (in physical units) in all directions.\n",
    "output_spacing = [0.1] * bspline.GetDimension()\n",
    "output_size = [\n",
    "    int(phys_sz / spc + 1) for phys_sz, spc in zip(physical_size, output_spacing)\n",
    "]\n",
    "displacement_field_transform = sitk.DisplacementFieldTransform(\n",
    "    sitk.TransformToDisplacementField(\n",
    "        bspline,\n",
    "        outputPixelType=sitk.sitkVectorFloat64,\n",
    "        size=output_size,\n",
    "        outputOrigin=bspline.GetTransformDomainOrigin(),\n",
    "        outputSpacing=output_spacing,\n",
    "        outputDirection=bspline.GetTransformDomainDirection(),\n",
    "    )\n",
    ")\n",
    "\n",
    "# Arbitrary point to evaluate the consistency of the two representations.\n",
    "# Change the value for the \"output_spacing\" above to evaluate its effect\n",
    "# on the transformation representation consistency.\n",
    "pnt = [0.4, -0.2]\n",
    "original_transformed = np.array(bspline.TransformPoint(pnt))\n",
    "secondary_transformed = np.array(displacement_field_transform.TransformPoint(pnt))\n",
    "print(f\"Original transformation result: {original_transformed}\")\n",
    "print(f\"Displacement field transformation result: {secondary_transformed}\")\n",
    "print(\n",
    "    f\"Difference between transformed points is: {np.linalg.norm(original_transformed - secondary_transformed)}\"\n",
    ")\n",
    "\n",
    "# Invert a displacement field transform\n",
    "displacement_image = displacement_field_transform.GetDisplacementField()\n",
    "bspline_inverse_displacement = sitk.DisplacementFieldTransform(\n",
    "    sitk.InvertDisplacementField(\n",
    "        displacement_image,\n",
    "        maximumNumberOfIterations=20,\n",
    "        maxErrorToleranceThreshold=0.01,\n",
    "        meanErrorToleranceThreshold=0.0001,\n",
    "        enforceBoundaryCondition=True,\n",
    "    )\n",
    ")\n",
    "\n",
    "\n",
    "# Transform the point using the original BSpline transformation and then back\n",
    "# via the displacement field inverse.\n",
    "there_and_back = np.array(\n",
    "    bspline_inverse_displacement.TransformPoint(bspline.TransformPoint(pnt))\n",
    ")\n",
    "print(f\"Original point: {pnt}\")\n",
    "print(f\"There and back point: {there_and_back}\")\n",
    "print(\n",
    "    f\"Difference between original and there-and-back points: {np.linalg.norm(pnt - there_and_back)}\"\n",
    ")"
   ]
  },
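  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the inversion filters depend on smoothness, consider the idea behind iterative inversion in a simplified 1D setting. If the forward mapping is $x \\mapsto x + u(x)$, the inverse displacement $v$ must satisfy $v(y) = -u(y + v(y))$. This suggests the fixed-point iteration $v_{k+1}(y) = -u(y + v_k(y))$, which is guaranteed to converge when $|u'| < 1$. A reflection, for example, violates this condition. The NumPy sketch below uses the arbitrary displacement $u(x) = 0.2\\sin(x)$. It illustrates the principle only and is not the exact algorithm of any of the SimpleITK filters listed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def forward_displacement(x):\n",
    "    # Arbitrary smooth 1D displacement with |u'(x)| <= 0.2.\n",
    "    return 0.2 * np.sin(x)\n",
    "\n",
    "\n",
    "def invert_displacement_1d(u, y, num_iterations=50):\n",
    "    \"\"\"Fixed-point iteration v <- -u(y + v) for the inverse displacement v(y).\"\"\"\n",
    "    v = np.zeros_like(y)\n",
    "    for _ in range(num_iterations):\n",
    "        v = -u(y + v)\n",
    "    return v\n",
    "\n",
    "\n",
    "y_samples = np.linspace(-np.pi, np.pi, 101)\n",
    "v_inverse = invert_displacement_1d(forward_displacement, y_samples)\n",
    "\n",
    "# Apply the inverse, then the forward mapping: the result should be y_samples again.\n",
    "x_mapped = y_samples + v_inverse\n",
    "round_trip = x_mapped + forward_displacement(x_mapped)\n",
    "print(f\"Maximal round-trip error: {np.max(np.abs(round_trip - y_samples))}\")"
   ]
  },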
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CompositeTransform\n",
    "\n",
    "This class represents a composition of transformations, multiple transformations applied one after the other. \n",
    "\n",
    "Whether you use a composite transformation or compose the transformations yourself makes a subtle difference in the registration framework.\n",
    "\n",
    "Below we represent the composite transformation $T_{affine}(T_{rigid}(x))$ in two ways: (1) use a composite transformation to contain the two; (2) combine the two into a single affine transformation. We can use both as initial transforms (SetInitialTransform) for the registration framework (ImageRegistrationMethod). The difference is that in the former case the optimized parameters belong to the rigid transformation and in the latter they belong to the combined affine transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a composite transformation: T_affine(T_rigid(x)).\n",
    "rigid_center = (100, 100, 100)\n",
    "theta_x = 0.0\n",
    "theta_y = 0.0\n",
    "theta_z = np.pi / 2.0\n",
    "rigid_translation = (1, 2, 3)\n",
    "rigid_euler = sitk.Euler3DTransform(\n",
    "    rigid_center, theta_x, theta_y, theta_z, rigid_translation\n",
    ")\n",
    "\n",
    "affine_center = (20, 20, 20)\n",
    "affine_translation = (5, 6, 7)\n",
    "\n",
    "# Matrix is represented as a vector-like data in row major order.\n",
    "affine_matrix = np.random.random(9)\n",
    "affine = sitk.AffineTransform(affine_matrix, affine_translation, affine_center)\n",
    "\n",
    "# Using the composite transformation we just add them in (stack based, first in - last applied).\n",
    "composite_transform = sitk.CompositeTransform(affine)\n",
    "composite_transform.AddTransform(rigid_euler)\n",
    "\n",
    "# Create a single transform manually. This is a recipe for combining any two global transformations\n",
    "# into an affine transformation, T_0(T_1(x)):\n",
    "# A = A0*A1\n",
    "# c = c1\n",
    "# t = A0*[t1+c1-c0] + t0+c0-c1\n",
    "A0 = np.asarray(affine.GetMatrix()).reshape(3, 3)\n",
    "c0 = np.asarray(affine.GetCenter())\n",
    "t0 = np.asarray(affine.GetTranslation())\n",
    "\n",
    "A1 = np.asarray(rigid_euler.GetMatrix()).reshape(3, 3)\n",
    "c1 = np.asarray(rigid_euler.GetCenter())\n",
    "t1 = np.asarray(rigid_euler.GetTranslation())\n",
    "\n",
    "combined_mat = np.dot(A0, A1)\n",
    "combined_center = c1\n",
    "combined_translation = np.dot(A0, t1 + c1 - c0) + t0 + c0 - c1\n",
    "combined_affine = sitk.AffineTransform(\n",
    "    combined_mat.flatten(), combined_translation, combined_center\n",
    ")\n",
    "\n",
    "# Check if the two transformations are equivalent.\n",
    "print(\"Apply the two transformations to the same point cloud:\")\n",
    "print(\"\\t\", end=\"\")\n",
    "print_transformation_differences(composite_transform, combined_affine)\n",
    "\n",
    "print(\"Transform parameters:\")\n",
    "print(\"\\tComposite transform: \" + point2str(composite_transform.GetParameters(), 2))\n",
    "print(\"\\tCombined affine: \" + point2str(combined_affine.GetParameters(), 2))\n",
    "\n",
    "print(\"Fixed parameters:\")\n",
    "print(\n",
    "    \"\\tComposite transform: \" + point2str(composite_transform.GetFixedParameters(), 2)\n",
    ")\n",
    "print(\"\\tCombined affine: \" + point2str(combined_affine.GetFixedParameters(), 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When a composite transformation contains only global transformations we can combine all of them into a single affine transformation. This is a generalization of the operation shown in the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def composite2affine(composite_transform, result_center=None):\n",
    "    \"\"\"\n",
    "    Combine all of the composite transformation's contents to form an equivalent affine transformation.\n",
    "    Args:\n",
    "        composite_transform (SimpleITK.CompositeTransform): Input composite transform which contains only\n",
    "                                                            global transformations, possibly nested.\n",
    "        result_center (tuple,list): The desired center parameter for the resulting affine transformation.\n",
    "                                    If None, then set to [0,...]. This can be any arbitrary value, as it is\n",
    "                                    possible to change the transform center without changing the transformation\n",
    "                                    effect.\n",
    "    Returns:\n",
    "        SimpleITK.AffineTransform: Affine transformation that has the same effect as the input composite_transform.\n",
    "    \"\"\"\n",
    "    # Flatten the copy of the composite transform, so no nested composites.\n",
    "    flattened_composite_transform = sitk.CompositeTransform(composite_transform)\n",
    "    flattened_composite_transform.FlattenTransform()\n",
    "    tx_dim = flattened_composite_transform.GetDimension()\n",
    "    A = np.eye(tx_dim)\n",
    "    c = np.zeros(tx_dim) if result_center is None else result_center\n",
    "    t = np.zeros(tx_dim)\n",
    "    for i in range(flattened_composite_transform.GetNumberOfTransforms() - 1, -1, -1):\n",
    "        curr_tx = flattened_composite_transform.GetNthTransform(i).Downcast()\n",
    "        # The TranslationTransform interface is different from other\n",
    "        # global transformations.\n",
    "        if curr_tx.GetTransformEnum() == sitk.sitkTranslation:\n",
    "            A_curr = np.eye(tx_dim)\n",
    "            t_curr = np.asarray(curr_tx.GetOffset())\n",
    "            c_curr = np.zeros(tx_dim)\n",
    "        else:\n",
    "            A_curr = np.asarray(curr_tx.GetMatrix()).reshape(tx_dim, tx_dim)\n",
    "            c_curr = np.asarray(curr_tx.GetCenter())\n",
    "            # Some global transformations do not have a translation\n",
    "            # (e.g. ScaleTransform, VersorTransform)\n",
    "            get_translation = getattr(curr_tx, \"GetTranslation\", None)\n",
    "            if get_translation is not None:\n",
    "                t_curr = np.asarray(get_translation())\n",
    "            else:\n",
    "                t_curr = np.zeros(tx_dim)\n",
    "        A = np.dot(A_curr, A)\n",
    "        t = np.dot(A_curr, t + c - c_curr) + t_curr + c_curr - c\n",
    "\n",
    "    return sitk.AffineTransform(A.flatten(), t, c)\n",
    "\n",
    "\n",
    "# Create a nested composite transformation using the one from the\n",
    "# previous cell and add a scale and a translation.\n",
    "composite_transform.AddTransform(composite_transform)\n",
    "composite_transform.AddTransform(sitk.ScaleTransform(3, [1.2, 1.4, 2.0]))\n",
    "composite_transform.AddTransform(sitk.TranslationTransform(3, [1, 2, 3]))\n",
    "# Get the corresponding affine transformation\n",
    "simplified_composite = composite2affine(\n",
    "    composite_transform, result_center=[100, 200, 300]\n",
    ")\n",
    "\n",
    "# Check if the two transformations are equivalent.\n",
    "print(\"Apply the two transformations to the same point cloud:\")\n",
    "print(\"\\t\", end=\"\")\n",
    "print_transformation_differences(composite_transform, simplified_composite)\n",
    "\n",
    "print(\"Transform parameters:\")\n",
    "print(\"\\tComposite transform: \" + point2str(composite_transform.GetParameters(), 2))\n",
    "print(\"\\tCombined affine: \" + point2str(simplified_composite.GetParameters(), 2))\n",
    "\n",
    "print(\"Fixed parameters:\")\n",
    "print(\n",
    "    \"\\tComposite transform: \" + point2str(composite_transform.GetFixedParameters(), 2)\n",
    ")\n",
    "print(\"\\tCombined affine: \" + point2str(simplified_composite.GetFixedParameters(), 2))\n",
    "\n",
    "# Why doesn't the composite_transform seem to have fixed parameters?\n",
    "# The last, n'th, transformation in the composite_transform is a TranslationTransform and that has no fixed parameters."
   ]
  },
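  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The composition recipe can also be checked with homogeneous coordinates. A global transformation $T(\\mathbf{x}) = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$ corresponds to a $(d+1)\\times(d+1)$ matrix whose upper-left block is $A$ and whose last column holds the offset $\\mathbf{t} + \\mathbf{c} - A\\mathbf{c}$. Consequently, $T_0(T_1(\\mathbf{x}))$ corresponds to the matrix product $M_0 M_1$. The NumPy sketch below uses random parameters and hypothetical helper names to check that this product agrees with the recipe $A = A_0A_1$, $\\mathbf{c} = \\mathbf{c_1}$, $\\mathbf{t} = A_0(\\mathbf{t_1}+\\mathbf{c_1}-\\mathbf{c_0}) + \\mathbf{t_0}+\\mathbf{c_0}-\\mathbf{c_1}$ that `composite2affine` applies repeatedly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_homogeneous(A, t, c):\n",
    "    \"\"\"Return the (d+1)x(d+1) matrix of the mapping x -> A(x-c) + t + c.\"\"\"\n",
    "    d = A.shape[0]\n",
    "    M = np.eye(d + 1)\n",
    "    M[:d, :d] = A\n",
    "    M[:d, d] = t + c - A.dot(c)  # the offset\n",
    "    return M\n",
    "\n",
    "\n",
    "def check_composition_recipe(rng, dim=3):\n",
    "    \"\"\"Compare the homogeneous matrix product with the A, c, t composition recipe.\"\"\"\n",
    "    A0, A1 = rng.random((dim, dim)), rng.random((dim, dim))\n",
    "    t0, t1 = rng.random(dim), rng.random(dim)\n",
    "    c0, c1 = rng.random(dim), rng.random(dim)\n",
    "    # Recipe for T_0(T_1(x)).\n",
    "    A = A0.dot(A1)\n",
    "    c = c1\n",
    "    t = A0.dot(t1 + c1 - c0) + t0 + c0 - c1\n",
    "    M_product = to_homogeneous(A0, t0, c0).dot(to_homogeneous(A1, t1, c1))\n",
    "    return np.allclose(M_product, to_homogeneous(A, t, c))\n",
    "\n",
    "\n",
    "print(check_composition_recipe(np.random.default_rng(42)))"
   ]
  },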
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Composite transforms enable a combination of a global transformation with multiple local/bounded transformations. This is useful if we want to apply deformations only in regions that deform, while other regions are only affected by the global transformation.\n",
    "\n",
    "The following code illustrates this, where the whole region is translated and subregions have different deformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global transformation.\n",
    "translation = sitk.TranslationTransform(2, (1.0, 0.0))\n",
    "\n",
    "# Displacement in region 1.\n",
    "displacement1 = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10, 20]\n",
    "field_origin = [-1.0, -1.0]\n",
    "field_spacing = [2.0 / 9.0, 2.0 / 19.0]\n",
    "field_direction = [1, 0, 0, 1]  # direction cosine matrix (row major order)\n",
    "\n",
    "# Concatenate all the information into a single list.\n",
    "displacement1.SetFixedParameters(\n",
    "    field_size + field_origin + field_spacing + field_direction\n",
    ")\n",
    "displacement1.SetParameters(np.ones(len(displacement1.GetParameters())))\n",
    "\n",
    "# Displacement in region 2.\n",
    "displacement2 = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10, 20]\n",
    "field_origin = [1.0, -3]\n",
    "field_spacing = [2.0 / 9.0, 2.0 / 19.0]\n",
    "field_direction = [1, 0, 0, 1]  # direction cosine matrix (row major order)\n",
    "\n",
    "# Concatenate all the information into a single list.\n",
    "displacement2.SetFixedParameters(\n",
    "    field_size + field_origin + field_spacing + field_direction\n",
    ")\n",
    "displacement2.SetParameters(-1.0 * np.ones(len(displacement2.GetParameters())))\n",
    "\n",
    "# Composite transform which applies the global and local transformations.\n",
    "composite = sitk.CompositeTransform([translation, displacement1, displacement2])\n",
    "\n",
    "# Apply the composite transformation to points in ([-1,-3],[3,1]) and\n",
    "# display the deformation using a quiver plot.\n",
    "\n",
    "# Generate points.\n",
    "numSamplesX = 10\n",
    "numSamplesY = 10\n",
    "coordsX = np.linspace(-1.0, 3.0, numSamplesX)\n",
    "coordsY = np.linspace(-3.0, 1.0, numSamplesY)\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "# Transform points and compute deformation vectors.\n",
    "pointsX = np.zeros(XX.shape)\n",
    "pointsY = np.zeros(XX.shape)\n",
    "for index, value in np.ndenumerate(XX):\n",
    "    px, py = composite.TransformPoint((value, YY[index]))\n",
    "    pointsX[index] = px - value\n",
    "    pointsY[index] = py - YY[index]\n",
    "\n",
    "plt.quiver(XX, YY, pointsX, pointsY);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inverting Composite Transform\n",
    "\n",
    "When a `CompositeTransform` is:\n",
    "1. only comprised of global transformations, all we need to do is call its `GetInverse` method.\n",
    "2. comprised of both global and bounded transformations, the `GetInverse` method will fail because inverting the bounded transformations requires additional information which is not available as part of the transformation.\n",
    "    \n",
    "The next cell shows how to invert a `CompositeTransform` for the generic case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def invert_composite_transform(\n",
    "    original_transform, displacement_field_inverter, grid_spacing=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Invert the given CompositeTransform. Note that the original\n",
    "    transform is modified, flattened. We do not create a copy\n",
    "    of the original because of the large memory usage associated\n",
    "    with the bounded transformations. If the caller wants to retain\n",
    "    the original nested structure of the CompositeTransform it is up\n",
    "    to them to create a copy prior to calling this method.\n",
    "    Args:\n",
    "        original_transform: A CompositeTransform containing global transforms,\n",
    "                            bounded transforms and nested composite transforms.\n",
    "        displacement_field_inverter: Configured object for inverting a displacement\n",
    "                                     field. One of InvertDisplacementFieldImageFilter,\n",
    "                                     InverseDisplacementFieldImageFilter,\n",
    "                                     IterativeInverseDisplacementFieldImageFilter.\n",
    "        grid_spacing: The grid spacing to use for approximating internal BSplineTransforms.\n",
    "                      Required if the CompositeTransform contains a BSplineTransform.\n",
    "                      Finer grids provide a better approximation at the cost of a larger\n",
    "                      memory footprint.\n",
    "    Returns:\n",
    "        CompositeTransform which is the inverse of the given one.\n",
    "    \"\"\"\n",
    "    inverted_transform_list = []\n",
    "    original_transform.FlattenTransform()\n",
    "    for i in range(original_transform.GetNumberOfTransforms() - 1, -1, -1):\n",
    "        # Downcast so that type specific methods (e.g. GetTransformDomainOrigin) are available.\n",
    "        tx = original_transform.GetNthTransform(i).Downcast()\n",
    "        ttype = tx.GetTransformEnum()\n",
    "        if ttype == sitk.sitkDisplacementField:\n",
    "            inverted_transform_list.append(\n",
    "                sitk.DisplacementFieldTransform(\n",
    "                    displacement_field_inverter.Execute(tx.GetDisplacementField())\n",
    "                )\n",
    "            )\n",
    "        elif ttype == sitk.sitkBSplineTransform:\n",
    "            if grid_spacing is None:\n",
    "                raise ValueError(\n",
    "                    \"grid_spacing is required when the composite contains a BSplineTransform.\"\n",
    "                )\n",
    "            # Convert the BSpline transform to a displacement field and then invert that transform\n",
    "            physical_size = tx.GetTransformDomainPhysicalDimensions()\n",
    "            grid_size = [\n",
    "                int(phys_sz / spc + 1)\n",
    "                for phys_sz, spc in zip(physical_size, grid_spacing)\n",
    "            ]\n",
    "            displacement_field_image = sitk.TransformToDisplacementField(\n",
    "                tx,\n",
    "                outputPixelType=sitk.sitkVectorFloat64,\n",
    "                size=grid_size,\n",
    "                outputOrigin=tx.GetTransformDomainOrigin(),\n",
    "                outputSpacing=grid_spacing,\n",
    "                outputDirection=tx.GetTransformDomainDirection(),\n",
    "            )\n",
    "            inverted_transform_list.append(\n",
    "                sitk.DisplacementFieldTransform(\n",
    "                    displacement_field_inverter.Execute(displacement_field_image)\n",
    "                )\n",
    "            )\n",
    "        else:\n",
    "            inverted_transform_list.append(tx.GetInverse())\n",
    "    return sitk.CompositeTransform(inverted_transform_list)\n",
    "\n",
    "\n",
    "# inverting a CompositeTransform:\n",
    "# 1. Select the inversion algorithm and configure it (possibly use default configuration).\n",
    "# 2. Call the invert_composite_transform function.\n",
    "\n",
    "df_inverter = sitk.InvertDisplacementFieldImageFilter()\n",
    "df_inverter.SetMaximumNumberOfIterations(100)\n",
    "df_inverter.SetEnforceBoundaryCondition(True)\n",
    "\n",
    "composite_inverse = invert_composite_transform(composite, df_inverter)\n",
    "\n",
    "# display the inverse composite transform using a quiver plot\n",
    "pointsX = np.zeros(XX.shape)\n",
    "pointsY = np.zeros(XX.shape)\n",
    "for index, value in np.ndenumerate(XX):\n",
    "    px, py = composite_inverse.TransformPoint((value, YY[index]))\n",
    "    pointsX[index] = px - value\n",
    "    pointsY[index] = py - YY[index]\n",
    "\n",
    "plt.quiver(XX, YY, pointsX, pointsY);"
   ]
  },
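  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a composite made only of global transformations (case 1 above), each inverse has a closed form that keeps the center. Solving $\\mathbf{y} = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$ for $\\mathbf{x}$ gives $\\mathbf{x} = A^{-1}(\\mathbf{y}-\\mathbf{c}) - A^{-1}\\mathbf{t} + \\mathbf{c}$, that is matrix $A^{-1}$, center $\\mathbf{c}$ and translation $-A^{-1}\\mathbf{t}$, provided $A$ is invertible. The NumPy sketch below checks this with random parameters. `invert_global` and `apply_global` are hypothetical helpers; in SimpleITK you would simply call `GetInverse`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def invert_global(A, t, c):\n",
    "    \"\"\"Return (A_inv, t_inv, c) describing the inverse of x -> A(x-c) + t + c, assuming A is invertible.\"\"\"\n",
    "    A_inv = np.linalg.inv(A)\n",
    "    return A_inv, -A_inv.dot(t), c\n",
    "\n",
    "\n",
    "def apply_global(A, t, c, x):\n",
    "    \"\"\"Apply x -> A(x-c) + t + c.\"\"\"\n",
    "    return A.dot(x - c) + t + c\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "A_g = rng.random((3, 3)) + 3.0 * np.eye(3)  # strictly diagonally dominant, hence invertible\n",
    "t_g = rng.random(3)\n",
    "c_g = rng.random(3)\n",
    "A_inv, t_inv, c_inv = invert_global(A_g, t_g, c_g)\n",
    "\n",
    "x_pt = rng.random(3)\n",
    "x_back = apply_global(A_inv, t_inv, c_inv, apply_global(A_g, t_g, c_g, x_pt))\n",
    "print(np.allclose(x_back, x_pt))"
   ]
  },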
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform\n",
    "\n",
    "This class represents a generic transform. Underneath the generic facade is one of the concrete transformation classes. To access the underlying object we call the `Downcast` method. This gives us an object of the concrete transformation type, but we don't know in advance which of the concrete types it is. To find the specific type we can query the transform to obtain its [TransformEnum](https://simpleitk.org/doxygen/latest/html/namespaceitk_1_1simple.html#a527cb966ed81d0bdc65999f4d2d4d852)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anonymous_transform_type = sitk.Transform(sitk.TranslationTransform(2, (1.0, 0.0)))\n",
    "\n",
    "try:\n",
    "    print(anonymous_transform_type.GetOffset())\n",
    "except AttributeError:\n",
    "    print(\"The generic transform does not have this method.\")\n",
    "\n",
    "actual_transform_type = anonymous_transform_type.Downcast()\n",
    "# Check that the actual transform type is indeed a translation before\n",
    "# calling a translation specific method.\n",
    "if actual_transform_type.GetTransformEnum() == sitk.sitkTranslation:\n",
    "    print(actual_transform_type.GetOffset())"
   ]
  },
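  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `Downcast` returns an instance of the concrete Python class, an `isinstance` check is an alternative to comparing TransformEnum values. The minimal sketch below uses an arbitrary Euler2DTransform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "generic_tx = sitk.Transform(sitk.Euler2DTransform())\n",
    "concrete_tx = generic_tx.Downcast()\n",
    "\n",
    "print(type(concrete_tx).__name__)\n",
    "if isinstance(concrete_tx, sitk.Euler2DTransform):\n",
    "    # Safe to call Euler2DTransform specific methods.\n",
    "    print(concrete_tx.GetAngle())"
   ]
  },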
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing and Reading\n",
    "\n",
    "The SimpleITK.ReadTransform() returns a SimpleITK.Transform . The content of the file can be any of the SimpleITK transformations or a composite (set of transformations). \n",
    "\n",
    "The transformation file formats supported by SimpleITK include *.txt*, *.tfm*, *.xfm*, *.hdf* and *.mat*. The former three are ASCII based formats and are more appropriate for saving global domain transformations, which are also easily understood by a human reader due to their limited number of parameters. The later two, *.hdf* and *.mat*, are binary formats and more appropriate for saving bounded domain transformations as those have a large number of parameters which are better saved using a binary file, faster IO, and also not readily understood by a human reader.\n",
    "\n",
    "**Note**: Writing of nested composite transforms is not supported, you will need to \"flatten\" the transform before writing it to file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Create a 2D rigid transformation, write it to disk and read it back.\n",
    "basic_transform = sitk.Euler2DTransform()\n",
    "basic_transform.SetTranslation((1, 2))\n",
    "basic_transform.SetAngle(np.pi / 2)\n",
    "\n",
    "full_file_name = os.path.join(OUTPUT_DIR, \"euler2D.tfm\")\n",
    "\n",
    "sitk.WriteTransform(basic_transform, full_file_name)\n",
    "read_result = sitk.ReadTransform(full_file_name)\n",
    "print_transformation_differences(basic_transform, read_result)\n",
    "\n",
    "# Create a composite transform then write and read.\n",
    "displacement = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10, 20]\n",
    "field_origin = [-10.0, -100.0]\n",
    "field_spacing = [20.0 / (field_size[0] - 1), 200.0 / (field_size[1] - 1)]\n",
    "field_direction = [1, 0, 0, 1]  # direction cosine matrix (row major order)\n",
    "\n",
    "# Concatenate all the information into a single list.\n",
    "displacement.SetFixedParameters(\n",
    "    field_size + field_origin + field_spacing + field_direction\n",
    ")\n",
    "displacement.SetParameters(np.random.random(len(displacement.GetParameters())))\n",
    "\n",
    "composite_transform = sitk.CompositeTransform([basic_transform, displacement])\n",
    "\n",
    "full_file_name = os.path.join(OUTPUT_DIR, \"composite.tfm\")\n",
    "\n",
    "sitk.WriteTransform(composite_transform, full_file_name)\n",
    "read_result = sitk.ReadTransform(full_file_name)\n",
    "print_transformation_differences(composite_transform, read_result)"
   ]
  },
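  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short sketch of the format choice for a global transformation, assuming a temporary directory from Python's `tempfile` module instead of `OUTPUT_DIR`: the same rigid transformation is written to ASCII (*.txt*, *.tfm*) and binary (*.hdf*) files, with the output format selected by the file extension, then read back and compared by mapping a single point. The test point and the file names are arbitrary choices for illustration, and the reported file sizes will depend on the SimpleITK/ITK version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# 2D rigid transformation: rotation by pi/2 and a (1, 2) translation.\n",
    "rigid = sitk.Euler2DTransform()\n",
    "rigid.SetAngle(np.pi / 2)\n",
    "rigid.SetTranslation((1, 2))\n",
    "\n",
    "test_point = (3.0, -4.0)\n",
    "round_trip_errors = {}\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    for extension in [\".txt\", \".tfm\", \".hdf\"]:\n",
    "        file_name = os.path.join(tmp_dir, \"rigid\" + extension)\n",
    "        # The file extension determines which transform file format is written.\n",
    "        sitk.WriteTransform(rigid, file_name)\n",
    "        read_back = sitk.ReadTransform(file_name)\n",
    "        # Distance between the original and the read back mapping of one point.\n",
    "        round_trip_errors[extension] = np.linalg.norm(\n",
    "            np.array(rigid.TransformPoint(test_point))\n",
    "            - np.array(read_back.TransformPoint(test_point))\n",
    "        )\n",
    "        print(\n",
    "            f\"{extension}: {os.path.getsize(file_name)} bytes, \"\n",
    "            f\"round trip error {round_trip_errors[extension]:.2e}\"\n",
    "        )"
   ]
  },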
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_translation = sitk.TranslationTransform(2, [1, 0])\n",
    "y_translation = sitk.TranslationTransform(2, [0, 1])\n",
    "# Create composite transform with the x_translation repeated 3 times\n",
    "composite_transform1 = sitk.CompositeTransform([x_translation] * 3)\n",
    "\n",
    "# Create a nested composite transform\n",
    "composite_transform = sitk.CompositeTransform([y_translation, composite_transform1])\n",
    "\n",
    "full_file_name = os.path.join(OUTPUT_DIR, \"composite.tfm\")\n",
    "\n",
    "# We cannot write nested composite transformations, will throw an exception so we\n",
    "# flatten it (unravel the nested part)\n",
    "try:\n",
    "    print(\n",
    "        f\"Nested composite transform contains {composite_transform.GetNumberOfTransforms()} transforms.\"\n",
    "    )\n",
    "    sitk.WriteTransform(composite_transform, full_file_name)\n",
    "except RuntimeError:\n",
    "    print(\"Failed writting nested composite transform.\")\n",
    "    composite_transform.FlattenTransform()\n",
    "    print(\n",
    "        f\"Nested composite transform after flattening contains {composite_transform.GetNumberOfTransforms()} transforms.\"\n",
    "    )\n",
    "    sitk.WriteTransform(composite_transform, full_file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the next cells we create a displacement field which has a nominal size for a CT/MR 512x512x100. We then save it using a text format and a binary format illustrating that IO is orders of magnitude faster when using the binary format."
   ]
  },
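  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The full-size field below holds 100x512x512 voxels with 3 float64 components each, roughly 630 MB in memory. As a lighter alternative, the cell that follows sketches the same comparison on a much smaller field. The 20x64x64 size is an arbitrary assumption chosen so the cell runs quickly; it times a single write with `time.perf_counter` and reports the file size, so the numbers will vary between machines and may show a smaller gap than the full-size field."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# Small random 3D displacement field: 20 slices of 64x64 voxels, 3 components each.\n",
    "# A float64 array with isVector=True yields the required sitkVectorFloat64 pixel type.\n",
    "field_image = sitk.GetImageFromArray(np.random.random([20, 64, 64, 3]), isVector=True)\n",
    "# Note: constructing the transform clears field_image.\n",
    "small_field_transform = sitk.DisplacementFieldTransform(field_image)\n",
    "\n",
    "io_report = {}\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    for extension in [\".tfm\", \".hdf\"]:\n",
    "        file_name = os.path.join(tmp_dir, \"small_deformation\" + extension)\n",
    "        start = time.perf_counter()\n",
    "        sitk.WriteTransform(small_field_transform, file_name)\n",
    "        elapsed = time.perf_counter() - start\n",
    "        io_report[extension] = (elapsed, os.path.getsize(file_name))\n",
    "        print(f\"{extension}: {elapsed:.3f} s, {io_report[extension][1]} bytes\")"
   ]
  },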
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "displacement_field_transform = sitk.DisplacementFieldTransform(\n",
    "    sitk.GetImageFromArray(np.random.random([100, 512, 512, 3]))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit -r1 -n1\n",
    "sitk.WriteTransform(\n",
    "    displacement_field_transform, os.path.join(OUTPUT_DIR, \"deformation.tfm\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit -r1 -n1\n",
    "sitk.WriteTransform(\n",
    "    displacement_field_transform, os.path.join(OUTPUT_DIR, \"deformation.hdf\")\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
